-
Notifications
You must be signed in to change notification settings - Fork 12.3k
llama : support Jamba hybrid Transformer-Mamba models #7531
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This will be necessary to support Jamba (and other recurrent models mixed with Attention). Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states.
* llama : begin work on support for variable GQA This will also be useful for Jamba if we consider the Mamba layers to have 0 KV heads. * llama : gracefully fail when not finding hybrid slot
llama.cpp
Outdated
switch (hparams.n_layer) { | ||
// TODO: Jamba layers are a bit heterogenous, so naming this is hard. | ||
case 12: // 900M 8x???M | ||
case 32: // 51B 16x?B | ||
default: model.type = e_model::MODEL_UNKNOWN; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure what model size type(s) I should give to Jamba models.
Great job! Works for me too, it's very fast. There were some warnings during compilation, but nothing major.
|
Amazing work!
|
ggml.c
Outdated
if (n_rs > 1) { | ||
// multiple sequences means it's hard to know when it's the first time a state is read, | ||
// so copy them all over to the destination, just to be sure. | ||
for (int i3 = 0; i3 < n_kv; ++i3) { | ||
for (int i3 = 0; i3 < n_rs; ++i3) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm looking at adding the missing Metal kernels for SSM_CONV
and SSM_SCAN
. I'm wondering if this part of the kernels where we copy src0
-> dst
could be extracted outside of the operation via ggml_cpy
+ ggml_view
or ggml_acc
? Would simplify the implementation
Also, I still haven't understood the details of the computation, but if we find a way to express these ops via existing ops all together (e.g. using ggml_conv
, ggml_mul_mat
, ...), it would be preferred to do so, in order to reduce the amount of kernels that we have to write.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if this part of the kernels where we copy
src0
->dst
could be extracted outside of the operation viaggml_cpy
+ggml_view
orggml_acc
? Would simplify the implementation
Yes, this is definitely possible. I'll find a way to extract the copies outside.
if we find a way to express these ops via existing ops all together (e.g. using ggml_conv, ggml_mul_mat, ...), it would be preferred to do so, in order to reduce the amount of kernels that we have to write.
For SSM_SCAN
, I think there's a way to fully express it in terms of other ops, though it will use much more memory because of the big intermediate tensors, and new operators like SOFT_PLUS
and EXP
would be needed instead. But different lengths of simultaneous sequences might make a custom operator still necessary. I'll think about ways to make it simpler, especially since other recurrent architectures (like RWKV) will also need to work on multiple sequences per batch.
For simplifying SSM_CONV
, I don't think ggml_conv
supports working on independent 1D rolling windows with varying sequence lengths.
When working on a single sequence, though, it's quite simple to do the equivalent of ggml_ssm_conv
with a self-overlapping view, as I did in my original implementation which I described in more detail in #5328 (comment):
Setting nb[2]
to the element size makes the view self-overlapping.
But this would create too many nodes in the compute graph when done with multiple sequences (unless they're always all the same length in which case the 4th dimension could be used), so a custom operator is necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One idea that we might consider is to unfuse the n_rs
dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch
The main goal would be to simplify the SSM operators, and potentially express them as other existing ops if possible. But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has it's own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention. The main purpose of supporting this mode would be to achieve reproducible results during parallel decoding (currently, decoding the same sequence in parallel can yield slightly different results due to the unified KV cache).
Just throwing some thoughts that I have so far - will continue looking at the PR in the next days
Edit: I was writing this comment before I saw you posted - will take a look tomorrow
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One idea that we might consider is to unfuse the
n_rs
dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch
Yes, this would be doable, but would make the number of compute graph nodes scale with the number of sequences. (EDIT: if it's split when making ubatches, then the number of compute graph nodes can stay constant)
Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.
The recurrent steps are simpler for ubatches with sequence lengths of 1
, but prompt processing performance would be much slower than with a per-recurrent-architecture operator for longer sequences. Still thinking about ways to generalize this while keeping good performance.
But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has it's own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention.
For the transformer KV cache, if there's logic to make all sequences within a ubatch to have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.
I also think there's a way to keep the unified KV cache (one buffer) and chunk it to make each sequence have their own independent contiguous reserved cells. Batching sequences together might still be possible though, if the KQ mask gets another dimension (the number of sequences in the ubatch, and the number of new tokens per sequence instead of the batch size) so that these equal-sized "chunks" get processed independently in parallel. But this might not work (because the newly-calculated KV cells have to be copied in a bunch of not-regularly-spaced places), unless... unless maybe with some kind of ggml_set_rows
? Not sure about the transposed V cache, though.
A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if it's split when making ubatches, then the number of compute graph nodes can stay constant
No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance
Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.
For the transformer KV cache, if there's logic to make all sequences within a ubatch to have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.
Not sure how that would work. Adding dummy tokens sounds too much overhead (at least
in the case of the regular transformer). Any other ideas?
A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).
From a broad PoV, if we have an implementation that works with a single-sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using separate cache for each sequence. I didn't consider the number of nodes until you noted - so that might be a problem indeed.
I'm currently working on a big refactor of how Mamba (and Jamba) works to make all sequences of a sub-batch be of the same length (initially only for models with recurrent states), and to make recurrent state slots contiguous, with the goal of simplifying the SSM operations (and removing GGML_OP_SSM_CONV), so that GPU support will be much easier to implement after that.
Looking forward to this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance
It will sacrifice some performance, but only in the cases where a batch contains an unequal number of tokens for each affected sequence. So this should not affect large prompt processing or parallel text generation, if both are not done in the same batch.
Not sure how that would work. Adding dummy tokens sounds too much overhead (at least
in the case of the regular transformer). Any other ideas?
This is not about adding dummy tokens, but about making the number of new tokens in each ubatch the same per sequence. I think the overhead will be minmal, though there is still some.
Let me illustrate.
Let's say there's a batch with new tokens for 4 sequences of length 16, 7, 1, 1, respectively.
0: ################
1: #######
2: #
3: #
Splitting that into equal-length sequences would make 3 ubatches, like so:
0: #
1: #
2: #
3: #
0: ######
1: ######
0: #########
Each of these shapes are nice and rectangular, which is good for recurrent architectures because their operations can be more easily batched across sequences this way.
But I'm not yet sure if it would also benefit Transformers, which is why I'm thinking of initially only enabling the equal-length splitting for recurrent (or hybrid) model architectures.
From a broad PoV, if we have an implementation that works with a single-sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using separate cache for each sequence. I didn't consider the number of nodes until you noted - so that might be a problem indeed.
Doing this with a constant number of graph nodes is pretty much what using same-length sequences (as illustrated above) allows, because the split into same-sequence tokens can then simply become another tensor dimension.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aha, got it. Good idea. I'm also not sure if this can help Transformers, but it's something to think about 👍
@compilade Is there any outstanding testing before merging this that I can help with? |
@gabe-l-hart There's Jamba 1.7 which was released recently, and I was meaning to test at least Jamba-Mini-1.7 to see if it works (including with Since my network is quite slow, the only way I can test it in a reasonable amount of time would be in a remote instance, but I didn't get around to do that yet (I might today). |
Alright, let's see how fast the downloads are on my CUDA box ⌛ |
Things are looking good! (NOTE: Built off of GraniteFour which includes this branch as of Setup python convert_hf_to_gguf.py ~/models/ai21labs/AI21-Jamba-Mini-1.7/
./build/bin/llama-quantize /home/ghart/models/ai21labs/AI21-Jamba-Mini-1.7/AI21-Jamba-Mini-1.7-F16.gguf Q4_K_M
./build/bin/llama-cli -m ~/models/ai21labs/AI21-Jamba-Mini-1.7/ggml-model-Q4_K_M.gguf -p "You are a helpful AI assistant" --jinja Results chat.log
|
I updated to the latest tip ( |
src/llama-model.cpp
Outdated
@@ -10009,16 +10056,15 @@ struct llm_build_mamba : public llm_graph_context { | |||
|
|||
// TODO: skip computing output earlier for unused tokens | |||
|
|||
y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d)); | |||
y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d)); | |||
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); | |
y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I didn't know about this operator. I see it was added recently (#14158). Seems useful.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, and whenever Vulkan supports non-contiguous input we can remove the ggml_cont
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I remember I've added this ggml_cont to avoid an assertion error in the metal backend
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the metal implementation of silu
requires it to be contiguous, however the swiglu
implementation does not. :)
Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Sigbjørn Skjæret <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's gooo :)
@compilade I merged the latest
In case you haven't made the same changes locally, the merge resolution commit is 1334c7 followed by f8b81c to fix the hybrid input use. |
Jinx! I'll merge yours into mine again. The one outstanding question I have in this is whether we should make the order of Falcon H1 consistent between the various enum declarations and usages. On my merge resolution, I moved it so that Falcon H1 always comes directly after Falcon. |
@gabe-l-hart Thanks. I've also merged the changes here (apparently we were doing that at the same time
Hmm, you're right that usually, the order should be consistent. There might be some order dependencies between the structs on the I might tend toward this being fixed in its own PR (and then verifying the changes only move lines with |
Makes sense. I'll avoid shuffling things on GR4 and we can defer that to another PR (sidebar, I love learning new git tricks!) |
Some of the tensor names are common with Llama4
Ok, this time I think it's ready. Will merge after the CI passes. |
* origin/master: llama : support Jamba hybrid Transformer-Mamba models (ggml-org#7531) ggml : add ggml_scale_bias (ggml-org#14417)
* wip: llama : separate recurrent states from the KV cache This will be necessary to support Jamba (and other recurrent models mixed with Attention). Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states. * llama : use std::find for seq_nodes in llama_rs_cache * llama : state checkpoints for recurrent models * llama : correctly handle more edge cases for the rs cache * llama : rename many llama_kv_cache_* functions * llama : remove useless return value for some llama_cache_* functions * llama : rethink recurrent state cell counts * llama : begin work on support for variable GQA This will also be useful for Jamba if we consider the Mamba layers to have 0 KV heads. * llama : gracefully fail when not finding hybrid slot * llama : support Jamba * llama : fix BERT inference without KV cache * convert-hf : check for unprocessed Jamba experts * convert-hf : support Mini-Jamba conversion * llama : fix Jamba quantization sanity checks * llama : sequence-length-aware batch splitting * llama : use equal-sequence-length sub-batches for recurrent models * ggml : simplify SSM-related operators * llama : make recurrent state slot allocation contiguous * llama : adapt internal uses of batches to llama_ubatch * llama : fix batch split output count for embeddings * llama : minimize swaps when reordering logits This reduces overhead when running hellaswag on thousands of sequences with very small 100k params Mamba models. * llama : fix edge case finding batch seq_id of split recurrent cell This otherwise was a problem when running the HellaSwag benchmark with small batch sizes, making it crash. * llama : avoid copies for simple batch splits * ggml : make ggml_ssm_scan not modify its source tensors * llama : fix shared recurrent tail cell count for small ubatch sizes Otherwise it was impossible to run the 'parallel' example with '-ub 1' with a Mamba or Jamba model. * llama : fix .base() compilation error on Windows * llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL * ggml : allow GGML_OP_CONCAT to work on non-contiguous tensors The implementation already supported it, and this makes Mamba's conv step slightly faster. * mamba : fix non-contiguous usage of ggml_silu * llama : session saving and reloading for hybrid models * convert_hf : fix Jamba conversion * llama : fix mixed signedness comparison * llama : use unused n_embd_k_gqa in k_shift This also slightly reduces the diff from the master branch * llama : begin renaming llama_past back to llama_kv_cache * llama : remove implicit recurrent state rollbacks * llama : partially apply clang-format style * convert : fix jamba conv1d shape squeezing * graph : add back hybrid memory graph input But this time it contains the sub-cache graph inputs. This *should* make it easier to handle updating the inputs when caching the graph (eventually). * model : add Jamba to Mamba-specific hparams printing * jamba : remove redundant nullptr initializations * model : remove unnecessary prefix for tensor loading constants Co-authored-by: Sigbjørn Skjæret <[email protected]> * model : use ggml_swiglu_split for Mamba Co-authored-by: Sigbjørn Skjæret <[email protected]> * model : make falcon-h1 use shared mamba2 layer builder * memory : avoid referring to KV in recurrent cache logs * gguf-py : avoid adding duplicate tensor mappings for Jamba Some of the tensor names are common with Llama4 --------- Co-authored-by: Sigbjørn Skjæret <[email protected]>
* wip: llama : separate recurrent states from the KV cache This will be necessary to support Jamba (and other recurrent models mixed with Attention). Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states. * llama : use std::find for seq_nodes in llama_rs_cache * llama : state checkpoints for recurrent models * llama : correctly handle more edge cases for the rs cache * llama : rename many llama_kv_cache_* functions * llama : remove useless return value for some llama_cache_* functions * llama : rethink recurrent state cell counts * llama : begin work on support for variable GQA This will also be useful for Jamba if we consider the Mamba layers to have 0 KV heads. * llama : gracefully fail when not finding hybrid slot * llama : support Jamba * llama : fix BERT inference without KV cache * convert-hf : check for unprocessed Jamba experts * convert-hf : support Mini-Jamba conversion * llama : fix Jamba quantization sanity checks * llama : sequence-length-aware batch splitting * llama : use equal-sequence-length sub-batches for recurrent models * ggml : simplify SSM-related operators * llama : make recurrent state slot allocation contiguous * llama : adapt internal uses of batches to llama_ubatch * llama : fix batch split output count for embeddings * llama : minimize swaps when reordering logits This reduces overhead when running hellaswag on thousands of sequences with very small 100k params Mamba models. * llama : fix edge case finding batch seq_id of split recurrent cell This otherwise was a problem when running the HellaSwag benchmark with small batch sizes, making it crash. * llama : avoid copies for simple batch splits * ggml : make ggml_ssm_scan not modify its source tensors * llama : fix shared recurrent tail cell count for small ubatch sizes Otherwise it was impossible to run the 'parallel' example with '-ub 1' with a Mamba or Jamba model. * llama : fix .base() compilation error on Windows * llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL * ggml : allow GGML_OP_CONCAT to work on non-contiguous tensors The implementation already supported it, and this makes Mamba's conv step slightly faster. * mamba : fix non-contiguous usage of ggml_silu * llama : session saving and reloading for hybrid models * convert_hf : fix Jamba conversion * llama : fix mixed signedness comparison * llama : use unused n_embd_k_gqa in k_shift This also slightly reduces the diff from the master branch * llama : begin renaming llama_past back to llama_kv_cache * llama : remove implicit recurrent state rollbacks * llama : partially apply clang-format style * convert : fix jamba conv1d shape squeezing * graph : add back hybrid memory graph input But this time it contains the sub-cache graph inputs. This *should* make it easier to handle updating the inputs when caching the graph (eventually). * model : add Jamba to Mamba-specific hparams printing * jamba : remove redundant nullptr initializations * model : remove unnecessary prefix for tensor loading constants Co-authored-by: Sigbjørn Skjæret <[email protected]> * model : use ggml_swiglu_split for Mamba Co-authored-by: Sigbjørn Skjæret <[email protected]> * model : make falcon-h1 use shared mamba2 layer builder * memory : avoid referring to KV in recurrent cache logs * gguf-py : avoid adding duplicate tensor mappings for Jamba Some of the tensor names are common with Llama4 --------- Co-authored-by: Sigbjørn Skjæret <[email protected]>
This adds support for Jamba (fixes #6372). (https://arxiv.org/abs/2403.19887)
(this has been open for a while, and this description was very different originally (much broader scope), feel free to look at the edit history)
New features
llama.cpp
State checkpoints for recurrent modelsWorks best whenn_parallel
is at least 3 or 4 times the number of actual usersAllows backtracking tokens from the end of the last generation without having to reprocess the whole contextVery useful with theserver
example when trimming the stop string{model}.attention.head_count_kv
can now also be an array ofint32_t
, one value per layer0
kv heads are considered recurrent layers (Mamba, in the case of Jamba).Internal changes
build_mamba_layer
functions to a shared parent class between bothllm_build_mamba
andllm_build_jamba
.llm_graph_context::build_inp_mem_hybrid
llm_graph_input_mem_hybrid
llm_graph_input_rs
andllm_graph_input_attn_kv_unified
, and causes unnecessary duplication and overloads ofbuild_rs
andbuild_attn
.Future ideas
--parallel
to a big value while not unnecessarily limiting the context size of the clients of theserver
if there aren't many. (related to Parallelization / Batching Explanation #4130 (reply in thread))Testing
convert-hf-to-gguf.py
)main
)server
with backtrackingparallel
Example output of
jamba-900M-v0.13-KIx2
(click to expand)